// Copyright 2023 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
//
// This file implements a code generator that, given a Go source file and the
// name of an interface declared in it, emits a "forwarder" type whose methods
// delegate every call to an object returned by an accessor function.

package main

import (
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"strings"

	"github.com/cockroachdb/errors"
)

var (
	outputFile = flag.String("output", "", "Output file; defaults to stdout.")
	tyName     = flag.String("type", "", "Name of the generated forwarder (default is interface name + 'Forwarder').")
	noTyDecl   = flag.Bool("nodef", false, "Avoid generating the struct and constructor definitions.")
)

// usage prints the command-line synopsis, an example of the generated code
// and the flag defaults. Everything is written to flag.CommandLine.Output()
// (stderr unless redirected) so that the descriptive text and the flag list
// printed by flag.PrintDefaults end up on the same stream.
func usage() {
	fmt.Fprintln(flag.CommandLine.Output(), `usage: genforwarder [flags] <source> <interface>

This program generates a forwarder object for the given Go interface.
Each interface method is implemented by the forwarder such that
calls get forwarded to another object. For example:

    // original interface
    type Foo interface {
      Bar()
    }

    // generated forwarder:
    type FooForwarder struct {
    	fwFoo func() Foo
    }

    func NewFooForwarder(fwAccessorFn func() Foo) Foo {
    	return &FooForwarder{fwAccessorFn}
    }

    func (f *FooForwarder) Bar() {
    	f.fwFoo().Bar()
    }

Note: the generator does not currently work with generic
interfaces.

Flags:`)
	flag.PrintDefaults()
}

// main parses the command line, locates the requested interface in the given
// Go source file and writes the generated forwarder to the -output file, or
// to stdout when -output is empty. Any failure is reported on stderr and
// terminates the process with exit status 1 (via noerr).
func main() {
	flag.Usage = usage
	flag.Parse()

	if flag.NArg() != 2 {
		usage()
		os.Exit(1)
	}
	source := flag.Arg(0)
	itname := flag.Arg(1)

	fs := token.NewFileSet()
	file, err := parser.ParseFile(fs, source, nil, 0)
	noerr(errors.Wrapf(err, "parsing %s", source))

	it := findInterface(file, itname)
	if it == nil {
		noerr(errors.Newf("no interface found with name %s", itname))
	}

	outy := *tyName
	if outy == "" {
		outy = itname + "Forwarder"
	}

	// Render into memory first. A failure midway then leaves no truncated
	// output file behind, and the final write (which includes closing the
	// file) is error-checked in one place. A deferred Close would be skipped
	// by os.Exit inside noerr.
	var out strings.Builder
	noerr(genForwarder(&out, fs, file, it, source, itname, outy, !*noTyDecl))

	if *outputFile == "" {
		_, err = os.Stdout.WriteString(out.String())
		noerr(errors.Wrap(err, "writing to stdout"))
		return
	}
	// os.WriteFile creates/truncates with the same 0666 mode as os.Create.
	err = os.WriteFile(*outputFile, []byte(out.String()), 0666)
	noerr(errors.Wrapf(err, "writing %s", *outputFile))
}

// findInterface returns the interface type declared with the given name at
// the top level of file, or nil if there is no such interface.
func findInterface(file *ast.File, name string) *ast.InterfaceType {
	for _, decl := range file.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok || gen.Tok != token.TYPE {
			continue
		}
		for _, spec := range gen.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok || ts.Name.Name != name {
				continue
			}
			if it, ok := ts.Type.(*ast.InterfaceType); ok {
				return it
			}
		}
	}
	return nil
}

// genForwarder renders the complete forwarder source for the interface it,
// declared as itname in file, into w. The forwarder type is named outy. When
// withTypeDecl is false, the struct and constructor are omitted and only the
// methods are emitted. source is only used in the header comment.
func genForwarder(
	w *strings.Builder,
	fs *token.FileSet,
	file *ast.File,
	it *ast.InterfaceType,
	source, itname, outy string,
	withTypeDecl bool,
) error {
	p := func(tmpl string, args ...interface{}) {
		fmt.Fprintf(w, tmpl, args...)
	}
	p("// Code generated by fwgen; DO NOT EDIT.\n")
	p("// Generated from: %s\n\n", source)
	p("package %s\n\n", file.Name.Name)
	if len(file.Imports) > 0 {
		// Every import of the source file is copied verbatim. Imports the
		// interface does not use will be reported as unused by the compiler.
		p("import (\n")
		for _, imp := range file.Imports {
			if imp.Name != nil {
				p("\t%s %s\n", imp.Name.Name, imp.Path.Value)
				continue
			}
			p("\t%s\n", imp.Path.Value)
		}
		p(")\n\n")
	}

	// fwd is the struct field holding the accessor of the forwarded object.
	fwd := "fw" + itname
	if withTypeDecl {
		p("// %s forwards the methods of interface %s.\n", outy, itname)
		p("type %s struct {\n", outy)
		p("\t%s func() %s\n", fwd, itname)
		p("}\n\n")

		// Match the constructor's visibility to that of the type.
		cons := "New" + outy
		if outy[0] >= 'a' && outy[0] <= 'z' {
			cons = "new" + string(outy[0]-'a'+'A') + outy[1:]
		}

		p("// %s builds a forwarder for the interface %s.\n", cons, itname)
		p("func %s(fwAccessorFn func() %s) %s {\n", cons, itname, itname)
		p("\treturn &%s{fwAccessorFn}\n}\n\n", outy)
	}
	for _, m := range it.Methods.List {
		// Embedded interfaces are not *ast.FuncType and are skipped.
		ft, ok := m.Type.(*ast.FuncType)
		if !ok {
			continue
		}
		if err := genMethod(w, fs, m.Names[0].Name, ft, itname, outy, fwd); err != nil {
			return err
		}
	}
	return nil
}

// genMethod emits a method on *outy implementing the interface method name
// (with signature ft) by calling the same method on the object returned by
// the accessor stored in field fwd.
func genMethod(
	w *strings.Builder,
	fs *token.FileSet,
	name string,
	ft *ast.FuncType,
	itname, outy, fwd string,
) error {
	const recv = "f"
	p := func(tmpl string, args ...interface{}) {
		fmt.Fprintf(w, tmpl, args...)
	}

	// Parameters that are unnamed, blank, or would clash with the receiver
	// get fresh names "argN" that collide with no other parameter name.
	used := map[string]bool{recv: true}
	for _, field := range ft.Params.List {
		for _, ident := range field.Names {
			used[ident.Name] = true
		}
	}
	argNum := 0
	fresh := func() string {
		for {
			id := fmt.Sprintf("arg%d", argNum)
			argNum++
			if !used[id] {
				used[id] = true
				return id
			}
		}
	}

	p("// %s is part of the interface %s.\n", name, itname)
	p("func (%s *%s) %s(", recv, outy, name)
	var forwardedArgs []string
	ellipsis := false
	for i, field := range ft.Params.List {
		if i > 0 {
			p(", ")
		}
		if len(field.Names) == 0 {
			id := fresh()
			forwardedArgs = append(forwardedArgs, id)
			p("%s", id)
		} else {
			for j, ident := range field.Names {
				if j > 0 {
					p(", ")
				}
				id := ident.Name
				if id == "_" || id == recv {
					id = fresh()
				}
				forwardedArgs = append(forwardedArgs, id)
				p("%s", id)
			}
		}
		ty := field.Type
		if e, ok := ty.(*ast.Ellipsis); ok {
			// A variadic parameter must be spread again when forwarding.
			ellipsis = true
			p(" ...")
			ty = e.Elt
		} else {
			p(" ")
		}
		if err := format.Node(w, fs, ty); err != nil {
			return err
		}
	}
	p(")")

	// NumFields is nil-safe and counts every name of a multi-name field.
	hasResults := ft.Results.NumFields() > 0
	if hasResults {
		p(" (")
		n := 0
		for _, field := range ft.Results.List {
			// A field such as "x, y int" declares one result per name, so its
			// type must be repeated for each of them.
			count := len(field.Names)
			if count == 0 {
				count = 1
			}
			for k := 0; k < count; k++ {
				if n > 0 {
					p(", ")
				}
				n++
				if err := format.Node(w, fs, field.Type); err != nil {
					return err
				}
			}
		}
		p(")")
	}
	p(" {\n\t")
	if hasResults {
		p("return ")
	}
	// The accessor takes no arguments; it only yields the target object.
	p("%s.%s().%s(%s", recv, fwd, name, strings.Join(forwardedArgs, ", "))
	if ellipsis {
		p("...")
	}
	p(")\n}\n\n")
	return nil
}

// noerr terminates the program if err is non-nil. The error is printed to
// stderr followed by a newline, and the process exits with status 1. A nil
// error is a no-op, so callers may pass a possibly-nil error unconditionally.
func noerr(err error) {
	if err != nil {
		fmt.Fprintf(os.Stderr, "%s\n", err)
		os.Exit(1)
	}
}
